{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "e1c90e9d-27a9-4d02-b5ab-b9eed4d77aba",
   "metadata": {},
   "source": [
     "# Google Calendar Assistant with Llama 3.2 3B Tool Calling\n",
    "\n",
     "This notebook shows how to build a digital assistant that schedules meetings using the Llama 3.2 3B model. The core concepts are prompt engineering and tool calling. The demo shows how Llama can interact with third-party apps such as Google Contacts & Google Calendar to schedule a meeting requested by the user. Even though we use prompt engineering to achieve this, the approach does not degrade the model's ability to answer general queries, and it can be extended to other tasks in a similar manner without affecting the quality of existing ones.\n",
    "\n",
    "\n",
    "## Approach\n",
    "\n",
     "Instead of using a complex system prompt with multiple conditions and expecting Llama to perform various tasks accurately out of the box, we treat this as a two-step process:\n",
     "- Determine the user intent - task classification\n",
     "- Take action for the specific task using tool calling\n",
    "\n",
    "\n",
    "\n",
     "In the diagram shown below:\n",
     "- System prompt 1 determines the classification of the query.\n",
     "- In steps 2 & 3, we classify the task being requested.\n",
     "- System prompt 2 is chosen based on the classification result.\n",
     "- Steps 4 & 5 implement the classified task.\n",
     "- For the sake of the demo, we show 2 classes: General & Meeting.\n",
    "\n",
    "![Tool Calling Flow Diagram](./assets/flow_diagram.png)\n",
    "\n",
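     "The two-pass routing can be sketched as follows (a minimal sketch with hypothetical callables; the actual prompts and generation calls appear later in this notebook):\n",
     "\n",
     "```python\n",
     "def route(user_prompt, classify, run_general, run_meeting):\n",
     "    # First pass: the classifier returns a short label such as 'meeting' or 'general'\n",
     "    label = classify(user_prompt)\n",
     "    if 'meeting' in label:\n",
     "        # Second pass with the tool-calling system prompt\n",
     "        return run_meeting(user_prompt)\n",
     "    # Second pass with the general-assistant system prompt\n",
     "    return run_general(user_prompt)\n",
     "```\n",
     "\n",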
     "Each of these tasks has its own prompt: we use the same model with a different system prompt depending on the classification result.\n",
     "Additionally, this demo showcases how Llama can return two tool calls from a single prompt. In the Meeting case, Llama returns 2 function calls in step 5:\n",
    "```\n",
    "<function=google_contact>{{\"name\": \"John Constantine\"}}</function>\n",
     "<function=google_calendar>{{\"date\": \"Mar 31\", \"time\": \"5:30 pm\", \"attendees\": \"John Constantine\"}}</function>\n",
    "```\n",
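     "\n",
     "A minimal parsing sketch (a hypothetical helper; the notebook's `process_llama_output` below does this with explicit string indices). Note that the model emits doubled braces around the JSON arguments, which the sketch strips before `json.loads`:\n",
     "\n",
     "```python\n",
     "import json\n",
     "\n",
     "def parse_tool_calls(text):\n",
     "    # Extract (name, arguments) pairs from <function=...>...</function> spans\n",
     "    calls = []\n",
     "    for chunk in text.split('<function=')[1:]:\n",
     "        body = chunk.split('</function>')[0]\n",
     "        name, args = body.split('>', 1)\n",
     "        # Tolerate single or doubled braces around the JSON arguments\n",
     "        inner = args.strip().strip('{}')\n",
     "        calls.append((name, json.loads('{' + inner + '}')))\n",
     "    return calls\n",
     "```\n",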
    "\n",
     "Actions based on the tool-calling output:\n",
     "- The google_contact call returned by the model is used to call the [People API](https://developers.google.com/people) to look up the email address of the person of interest\n",
     "- The email address from the previous step, along with the other fields of the google_calendar tool call returned by the model, is used to call the [Calendar API](https://developers.google.com/calendar)\n",
    "\n",
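     "Chaining the two calls can be sketched as follows (a hypothetical helper; `lookup_email` and `create_event` stand in for the `google_contact` and `google_calendar` functions defined later):\n",
     "\n",
     "```python\n",
     "def schedule_from_calls(calls, lookup_email, create_event):\n",
     "    # calls: [('google_contact', {...}), ('google_calendar', {...})]\n",
     "    email = lookup_email(**calls[0][1])\n",
     "    cal_args = dict(calls[1][1])\n",
     "    # The calendar call expects both the name and the resolved email address\n",
     "    cal_args['attendees'] = {'name': cal_args['attendees'], 'email': email}\n",
     "    return create_event(**cal_args)\n",
     "```\n",
     "\n",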
     "The end result is that a Google Calendar meeting is scheduled with the person of interest at the specified date & time.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9ca337cc-7908-4daf-be57-66eeb4f8f703",
   "metadata": {},
   "source": [
    "## Load Llama 3.2-3B-Instruct model\n",
    "\n",
     "This demo also intends to show that tool calling can be done effectively even with the 3B model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "bb085f6c-f069-48ef-8085-32f8a301b090",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "efc10dfe3b754f5eb17271d10da9e9d1",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import torch\n",
    "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
    "\n",
    "model_id = \"meta-llama/Llama-3.2-3B-Instruct\"\n",
    "\n",
    "model = AutoModelForCausalLM.from_pretrained(\n",
    "    model_id, device_map=\"auto\", torch_dtype=torch.bfloat16)\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_id)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "53427131-4311-4637-bac8-3f65073dbe23",
   "metadata": {},
   "source": [
     "## Define functions to access the People (`google_contact`) and Google Calendar (`google_calendar`) APIs\n",
    "\n",
     "**Note:**\n",
     "Accessing Google APIs requires credentials (a one-time setup).\n",
     "\n",
     "Store your credentials in `credentials.json`. Please refer to the [instructions](https://developers.google.com/workspace/guides/create-credentials) to obtain them.\n",
     "\n",
     "I followed the steps in [service account](https://developers.google.com/identity/protocols/oauth2/service-account#creatinganaccount) to get my credentials.\n",
     "\n",
     "Accessing Google APIs also requires an **authentication token** in addition to the credentials. The functions defined below, `google_contact` and `google_calendar`, include the logic to generate & refresh this token.\n",
     "\n",
     "To get the authentication token for the People & Calendar APIs, you need a runnable browser on the machine where you execute these functions, for the **FIRST TIME** only; you will authenticate once in the browser. Executing these functions generates `token_contacts.json` & `token_calendar.json`.\n",
     "\n",
     "For subsequent calls, you don't need a runnable browser."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4ea9af60-6a0c-4a01-adf3-d620f5871d86",
   "metadata": {},
   "source": [
    "#### Install required libraries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cd102901-40bc-4d3c-8726-1dd64164086c",
   "metadata": {},
   "outputs": [],
   "source": [
    "! pip install google-api-python-client google-auth-oauthlib"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "0e212a74-4c9f-424c-a39b-9c5dc7fd45d5",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os.path\n",
    "\n",
    "from google.auth.transport.requests import Request\n",
    "from google.oauth2.credentials import Credentials\n",
    "from google_auth_oauthlib.flow import InstalledAppFlow\n",
    "from googleapiclient.discovery import build\n",
    "from googleapiclient.errors import HttpError\n",
    "from datetime import datetime, timedelta\n",
    "\n",
    "\n",
    "SCOPES = [\"https://www.googleapis.com/auth/contacts.readonly\"]\n",
    "\n",
    "TOKEN = \"token_contacts.json\"\n",
    "\n",
    "\n",
    "def google_contact(name):\n",
    "  \"\"\"Returns the email address in Google contacts for the given name.\n",
    "  \"\"\"\n",
    "  creds = None\n",
    "\n",
    "  # The file token_contacts.json stores the user's access and refresh tokens, and is\n",
    "  # created automatically when the authorization flow completes for the first\n",
    "  # time.\n",
    "  if os.path.exists(TOKEN):\n",
    "    creds = Credentials.from_authorized_user_file(TOKEN, SCOPES)\n",
    "\n",
    "  # If there are no (valid) credentials available, let the user log in.\n",
    "  if not creds or not creds.valid:\n",
    "    if creds and creds.expired and creds.refresh_token:\n",
    "      creds.refresh(Request())\n",
    "    else:\n",
    "      flow = InstalledAppFlow.from_client_secrets_file(\n",
    "          \"credentials.json\", SCOPES\n",
    "      )\n",
    "      creds = flow.run_local_server(port=0)\n",
    "    # Save the credentials for the next run\n",
    "    with open(TOKEN, \"w\") as token:\n",
    "      token.write(creds.to_json())\n",
    "\n",
    "  try:\n",
    "    service = build(\"people\", \"v1\", credentials=creds)\n",
    "\n",
    "    # Call the People API\n",
    "    results = (\n",
    "        service.people()\n",
    "        .connections()\n",
    "        .list(\n",
    "            resourceName=\"people/me\",\n",
    "            pageSize=10,\n",
    "            personFields=\"names,emailAddresses\",\n",
    "        )\n",
    "        .execute()\n",
    "    )\n",
    "    connections = results.get(\"connections\", [])\n",
    "\n",
     "    # Build a dictionary mapping display name -> email address\n",
     "    db = {}\n",
     "    for person in connections:\n",
     "      names = person.get(\"names\", [])\n",
     "      emails = person.get(\"emailAddresses\", [])\n",
     "      # Skip contacts that lack a display name or an email address\n",
     "      if not names or not emails:\n",
     "        continue\n",
     "      db[names[0].get(\"displayName\")] = emails[0].get(\"value\")\n",
     "\n",
     "    return db.get(name)\n",
    "  except HttpError as err:\n",
    "    print(err)\n",
    "\n",
    "\n",
    "CALSCOPES = [\"https://www.googleapis.com/auth/calendar\"]\n",
    "\n",
    "CALTOKEN = \"token_calendar.json\"\n",
    "\n",
    "def google_calendar(date, time, attendees):\n",
    "  \"\"\"Creates a meeting invite using Google Calendar API.\n",
    "  \"\"\"\n",
    "  creds = None\n",
    "  # The file token_calendar.json stores the user's access and refresh tokens, and is\n",
    "  # created automatically when the authorization flow completes for the first\n",
    "  # time.\n",
    "  if os.path.exists(CALTOKEN):\n",
    "    creds = Credentials.from_authorized_user_file(CALTOKEN, CALSCOPES)\n",
    "        \n",
    "  # If there are no (valid) credentials available, let the user log in.\n",
    "  if not creds or not creds.valid:\n",
    "    if creds and creds.expired and creds.refresh_token:\n",
    "      creds.refresh(Request())\n",
    "    else:\n",
    "      flow = InstalledAppFlow.from_client_secrets_file(\n",
    "          \"credentials.json\", CALSCOPES\n",
    "      )\n",
    "      creds = flow.run_local_server(port=0)\n",
    "    # Save the credentials for the next run\n",
    "    with open(CALTOKEN, \"w\") as token:\n",
    "      token.write(creds.to_json())\n",
    "\n",
    "  name, email = attendees[\"name\"], attendees[\"email\"]\n",
     "  # NOTE: the year is hardcoded for this demo; in real use derive it from datetime.now()\n",
     "  time_str = date + ', 2025 ' + time\n",
    "  # Convert the time string to a datetime object\n",
    "  dt_obj = datetime.strptime(time_str, '%b %d, %Y %I:%M %p')\n",
    "  dt_obj_end =  dt_obj + timedelta(minutes=30)\n",
    "  event = {\n",
    "    'summary': 'Meeting: Ankith | ' + name,\n",
    "    'location': 'MPK 21',\n",
    "    'description': 'Sync up meeting',\n",
    "    'start': {\n",
    "      'dateTime': dt_obj.strftime('%Y-%m-%dT%H:%M:%S'),\n",
    "      'timeZone': 'America/Los_Angeles',\n",
    "    },\n",
    "    'end': {\n",
     "      'dateTime': dt_obj_end.strftime('%Y-%m-%dT%H:%M:%S'),\n",
    "      'timeZone': 'America/Los_Angeles',\n",
    "    },\n",
    "    'attendees': [\n",
    "      {'email': email},\n",
    "    ],\n",
    "    'reminders': {\n",
    "      'useDefault': False,\n",
    "      'overrides': [\n",
    "        {'method': 'email', 'minutes': 24 * 60},\n",
    "        {'method': 'popup', 'minutes': 10},\n",
    "      ],\n",
    "    },\n",
    "  }\n",
    "  \n",
    "  service = build(\"calendar\", \"v3\", credentials=creds)\n",
    "\n",
    "  # Create a meeting invite\n",
    "  event = service.events().insert(calendarId='primary', body=event).execute()\n",
    "  \n",
    "  result = f\"\\nEvent created with {name} with email {email} on {date} at {time}: \\n {event.get('htmlLink')}\"\n",
    "  return result\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1106b0ba-9e0e-41bd-a6cd-d0d069a7fc4a",
   "metadata": {},
   "source": [
    "### Function to process tool calling output and call Google APIs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "5df96557-8beb-43b9-ac99-c1218a686676",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "functions = {\n",
    "  \"google_contact\": google_contact,\n",
    "  \"google_calendar\": google_calendar\n",
    "}\n",
    "\n",
    "\n",
    "def process_llama_output(llama_output):\n",
    "  \"\"\"\n",
    "  Function to process the tool calling output & make the actual Google API call\n",
    "  \"\"\"\n",
    "\n",
    "  result = \"\"\n",
    "  prefix = \"<function=\"\n",
    "  suffix = \"</function>\"\n",
    "  start, N = 0, len(llama_output)\n",
    "  count = 0\n",
    "  email = None\n",
    "\n",
     "  # In this example, we expect Llama to produce 2 tool-calling outputs.\n",
     "  # The model emits doubled braces, e.g. <function=name>{{...}}</function>;\n",
     "  # the slicing and the '>{' split below drop the outer braces before json.loads.\n",
    "  while count < 2:\n",
    "    begin, end = llama_output[start:].find(prefix) + len(prefix) + start, llama_output[start:].find(suffix) + start\n",
    "    func_name, params = llama_output[begin:end-1].split(\">{\")\n",
    "    end += len(suffix)\n",
    "    start = end\n",
    "    count += 1\n",
    "    params = json.loads(params.strip())\n",
    "    \n",
    "    if not email and func_name in functions:\n",
    "      email = functions[func_name](**params)\n",
    "    \n",
    "    elif email and func_name in functions:\n",
    "      attendees = {}\n",
    "      attendees[\"name\"] = params[\"attendees\"]\n",
    "      attendees[\"email\"] = email\n",
    "      params[\"attendees\"] = attendees\n",
    "\n",
    "      result += \"\\n-----------------------------------------------------\\n\"\n",
     "      result += functions[func_name](**params)\n",
     "      return result\n",
     "\n",
     "  # Fallback in case fewer than 2 valid tool calls were found\n",
     "  return result\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a94cecf8-239a-4d57-9f73-2281786ad6f2",
   "metadata": {},
   "source": [
    "### Functions for calling Llama for the 2 tasks: general & setup meeting"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "a75b7fac-7553-40e5-aedd-67ebdc98cce7",
   "metadata": {},
   "outputs": [],
   "source": [
    "def general_query(user_prompt, result):\n",
    "  \"\"\"\n",
    "  Function to call Llama for general queries\n",
    "  \"\"\"\n",
    "  \n",
    "  result += \"\\n-----------------------------------------------------\\n\"\n",
    "  result += \"SECOND PASS\"\n",
    "  result += \"\\n-----------------------------------------------------\\n\"\n",
    "\n",
    "  max_new_tokens = 100\n",
    "  SYSTEM_PROMPT1 = f\"\"\"\n",
    "\n",
    "  You are a helpful assistant. Answer the query in {max_new_tokens} words\n",
    "  \"\"\"\n",
    "\n",
    "  prompt = SYSTEM_PROMPT1 + f\"\\n\\nQuestion: {user_prompt} \\n\" + \"\\n Answer:\\n\"\n",
    "\n",
    "  input_ids = tokenizer(prompt, return_tensors=\"pt\").to(\"cuda\")\n",
    "\n",
    "  output = model.generate(**input_ids, max_new_tokens=max_new_tokens)\n",
    "\n",
    "  llama_output = tokenizer.decode(output[0], skip_special_tokens=True)\n",
    "\n",
    "\n",
    "  result += llama_output\n",
    "\n",
    "  return result\n",
    "\n",
    "\n",
    "def setup_meeting(user_prompt, result):\n",
    "  \"\"\"\n",
    "  Function to call Llama for setting up meeting using Google APIs\n",
    "  \"\"\"\n",
    "  \n",
    "  result += \"\\n-----------------------------------------------------\\n\"\n",
    "  result += \"SECOND PASS\"\n",
    "  result += \"\\n-----------------------------------------------------\\n\"\n",
    "\n",
    "  max_new_tokens = 100\n",
    "  SYSTEM_PROMPT2 = \"\"\"\n",
    "\n",
    "  You are a helpful assistant.\n",
    "  Here are the tools you have access to: google_contact and google_calendar which you can use for ONLY setting up meetings:\n",
    "  Use the function 'google_contact' to: look up a name and get the email address\n",
    "  {\n",
    "    \"name\": \"google_contact\",\n",
    "    \"description\": \"Get email address\",\n",
    "    \"parameters\": {\n",
    "      \"name\": {\n",
    "        \"param_type\": \"string\",\n",
    "        \"description\": \"name\",\n",
    "        \"required\": true\n",
    "      }\n",
    "    }\n",
    "  }\n",
    "  Use the function 'google_calendar' to: schedule a meeting\n",
    "  {\n",
    "    \"name\": \"google_calendar\",\n",
    "    \"description\": \"Schedule a meeting\",\n",
    "    \"parameters\": {\n",
    "      \"date\": {\n",
    "        \"param_type\": \"string\",\n",
    "        \"description\": \"date\",\n",
    "        \"required\": true\n",
    "      },\n",
     "      \"time\": {\n",
    "        \"param_type\": \"string\",\n",
    "        \"description\": \"time\",\n",
    "        \"required\": true\n",
    "      },\n",
     "      \"attendees\": {\n",
    "        \"param_type\": \"string\",\n",
    "        \"description\": \"name\",\n",
    "        \"required\": true\n",
    "      }\n",
    "    }\n",
    "  }\n",
     "  Identify the name in the query, look up the name using 'google_contact' to find the email address of the contact \n",
     "  and then use 'google_calendar' to schedule a meeting with this person.\n",
     "  DO NOT reply with any other reasoning or explanation\n",
    "  ONLY reply in the following format with no prefix or suffix:\n",
    "  <function=function_name_being_used>{{\"example_name\": \"example_value\"}}</function>\n",
    "  Reminder:\n",
    "  - ONLY return the function call and don't reply anything else\n",
    "  - DO NOT give any explanation\n",
    "  - I REPEAT ONLY return function calls\n",
    "  - Function calls MUST follow the specified format, start with <function= and end with </function>\n",
    "  - Required parameters MUST be specified\n",
    "  - First return google_contact and then google_calendar\n",
    "  \"\"\"\n",
    "\n",
    "\n",
    "  prompt = SYSTEM_PROMPT2 + f\"\\n\\nQuestion: {user_prompt} \\n\"\n",
    "\n",
    "  input_prompt_length = len(prompt)\n",
    "\n",
    "  input_ids = tokenizer(prompt, return_tensors=\"pt\").to(\"cuda\")\n",
    "\n",
    "  output = model.generate(**input_ids, max_new_tokens=max_new_tokens)\n",
    "\n",
    "  llama_output = tokenizer.decode(output[0], skip_special_tokens=True)\n",
    "\n",
    "\n",
    "  result += llama_output\n",
    "\n",
    "\n",
    "  result += \"\\n\\n\\n\"\n",
    "  result += process_llama_output(llama_output[input_prompt_length:])\n",
    "  return result"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e1260bc4-9a04-47b6-be29-f82d5c949642",
   "metadata": {},
   "source": [
     "## Demo with a general query\n",
     "\n",
     "To help explain the final output, we also print the system prompts used in both calls to Llama"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "48a57e28-82cf-4f8f-b9ac-2b357f7aea7e",
   "metadata": {},
   "outputs": [],
   "source": [
    "user_prompt = \"Tell me about Paris\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "f3c35189-4152-45a8-81ab-386466d6284c",
   "metadata": {},
   "outputs": [],
   "source": [
    "SYSTEM_PROMPT = \"\"\"\n",
    "\n",
    "  Classify the given prompt\n",
    "  A message can be classified as one of the following categories: meeting, general.\n",
    "\n",
    "  Examples:\n",
    "  - meeting: \"Can you schedule a meeting with Ankith on Dec 1 at 1 pm.\"\n",
    "  - general: \"Tell me about San Francisco\"\n",
    "\n",
    "  Reminder:\n",
    "  - ONLY return the classification label and NOTHING else\n",
    "  - DO NOT give any explanation or examples\n",
    "  \"\"\"\n",
    "\n",
    "prompt = SYSTEM_PROMPT + f\"\\n\\nQuestion: {user_prompt} \\n\" + \"\\n Answer:\\n\"\n",
    "\n",
    "input_prompt_length = len(prompt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "58801f79-08eb-49f2-bb35-aff1e026e41c",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.\n",
      "Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "-----------------------------------------------------\n",
      "FIRST PASS\n",
      "-----------------------------------------------------\n",
      "\n",
      "\n",
      "  Classify the given prompt\n",
      "  A message can be classified as one of the following categories: meeting, general.\n",
      "\n",
      "  Examples:\n",
      "  - meeting: \"Can you schedule a meeting with Ankith on Dec 1 at 1 pm.\"\n",
      "  - general: \"Tell me about San Francisco\"\n",
      "\n",
      "  Reminder:\n",
      "  - ONLY return the classification label and NOTHING else\n",
      "  - DO NOT give any explanation or examples\n",
      "  \n",
      "\n",
      "Question: Tell me about Paris \n",
      "\n",
      " Answer:\n",
      " general \n",
      "\n",
      "\n",
      "\n",
      "-----------------------------------------------------\n",
      "SECOND PASS\n",
      "-----------------------------------------------------\n",
      "\n",
      "\n",
      "  You are a helpful assistant. Answer the query in 100 words\n",
      "  \n",
      "\n",
      "Question: Tell me about Paris \n",
      "\n",
      " Answer:\n",
      "Paris, the City of Light, is the capital of France, known for its stunning architecture, art, fashion, and romance. The city is home to iconic landmarks like the Eiffel Tower, Notre-Dame Cathedral, and the Louvre Museum, which houses the Mona Lisa. The city's charming streets, cafes, and gardens offer a unique blend of history, culture, and beauty. Visitors can stroll along the Seine River, explore the Latin Quarter, and indulge in French cuisine, wine\n"
     ]
    }
   ],
   "source": [
    "\n",
    "result = \"\\n-----------------------------------------------------\\n\"\n",
    "result += \"FIRST PASS\"\n",
    "result += \"\\n-----------------------------------------------------\\n\"\n",
    "\n",
    "input_ids = tokenizer(prompt, return_tensors=\"pt\").to(\"cuda\")\n",
    "\n",
    "output = model.generate(**input_ids, max_new_tokens=2)\n",
    "\n",
    "llama_output = tokenizer.decode(output[0], skip_special_tokens=True)\n",
    "\n",
    "result += llama_output\n",
    "result += \"\\n\"\n",
    "\n",
    "\n",
    "classification = llama_output[input_prompt_length:].strip()\n",
    "\n",
    "\n",
     "if \"meeting\" in classification:\n",
     "    result = setup_meeting(user_prompt, result)\n",
    "elif \"general\" in classification:\n",
    "    result = general_query(user_prompt, result)\n",
    "print(result)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "93a74c4b-a918-43a2-8099-a4f0efb8902f",
   "metadata": {},
   "outputs": [],
   "source": [
    "user_prompt = \"Can you describe how an internal combustion engine works \"\n",
    "\n",
    "prompt = SYSTEM_PROMPT + f\"\\n\\nQuestion: {user_prompt} \\n\" + \"\\n Answer:\\n\"\n",
    "\n",
    "input_prompt_length = len(prompt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "b87f590d-d697-4f8f-af8d-5099c59019f0",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.\n",
      "Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "-----------------------------------------------------\n",
      "FIRST PASS\n",
      "-----------------------------------------------------\n",
      "\n",
      "\n",
      "  Classify the given prompt\n",
      "  A message can be classified as one of the following categories: meeting, general.\n",
      "\n",
      "  Examples:\n",
      "  - meeting: \"Can you schedule a meeting with Ankith on Dec 1 at 1 pm.\"\n",
      "  - general: \"Tell me about San Francisco\"\n",
      "\n",
      "  Reminder:\n",
      "  - ONLY return the classification label and NOTHING else\n",
      "  - DO NOT give any explanation or examples\n",
      "  \n",
      "\n",
      "Question: Can you describe how an internal combustion engine works  \n",
      "\n",
      " Answer:\n",
      "general\n",
      "\n",
      "\n",
      "\n",
      "-----------------------------------------------------\n",
      "SECOND PASS\n",
      "-----------------------------------------------------\n",
      "\n",
      "\n",
      "  You are a helpful assistant. Answer the query in 100 words\n",
      "  \n",
      "\n",
      "Question: Can you describe how an internal combustion engine works  \n",
      "\n",
      " Answer:\n",
      "An internal combustion engine is a type of heat engine that generates power by burning fuel, typically gasoline or diesel, inside a combustion chamber within the engine. The engine consists of several major components, including cylinders, pistons, crankshafts, and valves. Here's a simplified overview of the process:\n",
      "\n",
      "1. Air and fuel are drawn into the cylinders through the intake valves.\n",
      "2. The mixture is compressed by the pistons, creating a high-pressure mixture.\n",
      "3. The spark plug ignites the\n"
     ]
    }
   ],
   "source": [
    "\n",
    "result = \"\\n-----------------------------------------------------\\n\"\n",
    "result += \"FIRST PASS\"\n",
    "result += \"\\n-----------------------------------------------------\\n\"\n",
    "\n",
    "input_ids = tokenizer(prompt, return_tensors=\"pt\").to(\"cuda\")\n",
    "\n",
    "output = model.generate(**input_ids, max_new_tokens=2)\n",
    "\n",
    "llama_output = tokenizer.decode(output[0], skip_special_tokens=True)\n",
    "\n",
    "result += llama_output\n",
    "result += \"\\n\"\n",
    "\n",
    "\n",
    "classification = llama_output[input_prompt_length:].strip()\n",
    "\n",
    "\n",
     "if \"meeting\" in classification:\n",
     "    result = setup_meeting(user_prompt, result)\n",
    "elif \"general\" in classification:\n",
    "    result = general_query(user_prompt, result)\n",
    "print(result)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c022c93e-4886-40f4-b3a3-3384896f5d7c",
   "metadata": {},
   "source": [
     "## Demo with a meeting-setup query, which uses tool calling & calls Google APIs\n",
     "\n",
     "To help explain the final output, we also print the system prompts used in both calls to Llama"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "id": "437a1467-ff3c-4559-b778-fe725045b0ac",
   "metadata": {},
   "outputs": [],
   "source": [
    "user_prompt = \"Schedule a meeting with John Constantine on Mar 31 at 5:30 pm\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "id": "e346c0de-1fd4-43a5-b067-6347a4a87f83",
   "metadata": {},
   "outputs": [],
   "source": [
    "SYSTEM_PROMPT = \"\"\"\n",
    "\n",
    "  Classify the given prompt\n",
    "  A message can be classified as one of the following categories: meeting, general.\n",
    "\n",
    "  Examples:\n",
    "  - meeting: \"Can you schedule a meeting with Ankith on Dec 1 at 1 pm.\"\n",
    "  - general: \"Tell me about San Francisco\"\n",
    "\n",
    "  Reminder:\n",
    "  - ONLY return the classification label and NOTHING else\n",
    "  - DO NOT give any explanation or examples\n",
    "  \"\"\"\n",
    "\n",
    "prompt = SYSTEM_PROMPT + f\"\\n\\nQuestion: {user_prompt} \\n\" + \"\\n Answer:\\n\"\n",
    "\n",
    "input_prompt_length = len(prompt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "id": "fe5c51e0-facb-4889-9bea-0496db6b15d7",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.\n",
      "Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "-----------------------------------------------------\n",
      "FIRST PASS\n",
      "-----------------------------------------------------\n",
      "\n",
      "\n",
      "  Classify the given prompt\n",
      "  A message can be classified as one of the following categories: meeting, general.\n",
      "\n",
      "  Examples:\n",
      "  - meeting: \"Can you schedule a meeting with Ankith on Dec 1 at 1 pm.\"\n",
      "  - general: \"Tell me about San Francisco\"\n",
      "\n",
      "  Reminder:\n",
      "  - ONLY return the classification label and NOTHING else\n",
      "  - DO NOT give any explanation or examples\n",
      "  \n",
      "\n",
      "Question: Schedule a meeting with John Constantine on Mar 31 at 5:30 pm \n",
      "\n",
      " Answer:\n",
      "  meeting\n",
      "\n",
      "-----------------------------------------------------\n",
      "SECOND PASS\n",
      "-----------------------------------------------------\n",
      "\n",
      "\n",
      "  You are a helpful assistant.\n",
      "  Here are the tools you have access to: google_contact and google_calendar which you can use for ONLY setting up meetings:\n",
      "  Use the function 'google_contact' to: look up a name and get the email address\n",
      "  {\n",
      "    \"name\": \"google_contact\",\n",
      "    \"description\": \"Get email address\",\n",
      "    \"parameters\": {\n",
      "      \"name\": {\n",
      "        \"param_type\": \"string\",\n",
      "        \"description\": \"name\",\n",
      "        \"required\": true\n",
      "      }\n",
      "    }\n",
      "  }\n",
      "  Use the function 'google_calendar' to: schedule a meeting\n",
      "  {\n",
      "    \"name\": \"google_calendar\",\n",
      "    \"description\": \"Schedule a meeting\",\n",
      "    \"parameters\": {\n",
      "      \"date\": {\n",
      "        \"param_type\": \"string\",\n",
      "        \"description\": \"date\",\n",
      "        \"required\": true\n",
      "      },\n",
      "      \"time\"\": {\n",
      "        \"param_type\": \"string\",\n",
      "        \"description\": \"time\",\n",
      "        \"required\": true\n",
      "      },\n",
      "      \"attendees\"\": {\n",
      "        \"param_type\": \"string\",\n",
      "        \"description\": \"name\",\n",
      "        \"required\": true\n",
      "      }\n",
      "    }\n",
      "  }\n",
      "  Identify the name in the query, lookup the name using 'google_contact' to find the email address of the contact \n",
      "  and then use 'google_calndar' to schedule a meeeting with this person.\n",
      "  DO NOT reply with any other reasoning or exaplanation\n",
      "  ONLY reply in the following format with no prefix or suffix:\n",
      "  <function=function_name_being_used>{{\"example_name\": \"example_value\"}}</function>\n",
      "  Reminder:\n",
      "  - ONLY return the function call and don't reply anything else\n",
      "  - DO NOT give any explanation\n",
      "  - I REPEAT ONLY return function calls\n",
      "  - Function calls MUST follow the specified format, start with <function= and end with </function>\n",
      "  - Required parameters MUST be specified\n",
      "  - First return google_contact and then google_calendar\n",
      "  \n",
      "\n",
      "Question: \n",
      "\n",
      "  Classify the given prompt\n",
      "  A message can be classified as one of the following categories: meeting, general.\n",
      "\n",
      "  Examples:\n",
      "  - meeting: \"Can you schedule a meeting with Ankith on Dec 1 at 1 pm.\"\n",
      "  - general: \"Tell me about San Francisco\"\n",
      "\n",
      "  Reminder:\n",
      "  - ONLY return the classification label and NOTHING else\n",
      "  - DO NOT give any explanation or examples\n",
      "  \n",
      "\n",
      "Question: Schedule a meeting with John Constantine on Mar 31 at 5:30 pm \n",
      "\n",
      " Answer:\n",
      " \n",
      "  <function=google_contact>{{\"name\": \"John Constantine\"}}</function>\n",
      "  <function=google_calendar>{{\"date\": \"Mar 31\", \"time\": \"5:30 pm\", \"attendees\": \"John Constantine\"}}</function>\n",
      "\n",
      "\n",
      "\n",
      "-----------------------------------------------------\n",
      "\n",
      "Event created with John Constantine with email constantine@example.com on Mar 31 at 5:30 pm: \n",
      " https://www.google.com/calendar/event?eid=MG1rMnBsbzRvaGYwdXUzc2tydnZvNGZyNWMgYW5raXRodGVzdHRvb2xAbQ\n"
     ]
    }
   ],
   "source": [
    "\n",
    "result = \"\\n-----------------------------------------------------\\n\"\n",
    "result += \"FIRST PASS\"\n",
    "result += \"\\n-----------------------------------------------------\\n\"\n",
    "\n",
    "input_ids = tokenizer(prompt, return_tensors=\"pt\").to(\"cuda\")\n",
    "\n",
    "output = model.generate(**input_ids, max_new_tokens=2)\n",
    "\n",
    "llama_output = tokenizer.decode(output[0], skip_special_tokens=True)\n",
    "\n",
    "result += llama_output\n",
    "result += \"\\n\"\n",
    "\n",
    "\n",
    "classification = llama_output[input_prompt_length:].strip()\n",
    "\n",
     "if \"meeting\" in classification:\n",
     "    result = setup_meeting(user_prompt, result)\n",
    "elif \"general\" in classification:\n",
    "    result = general_query(user_prompt, result)\n",
    "print(result)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6f2b5788-762e-44d2-be26-b7e5bb823f68",
   "metadata": {},
   "source": [
    "## Result\n",
    "\n",
    "In your Google Calendar, you should see an invite generated as shown below\n",
    "\n",
    "<img src=\"./assets/google_calendar.png\" alt=\"Google Calendar Invite\" width=\"300\" height=\"200\">"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
